import { A_Const, A_Expr, ColumnRef, Node, RawStmt } from 'libpg-query/wasm'

/**
 * Checks whether a raw statement is a query, i.e. a node that
 * `unwrapQueryStatement` accepts.
 *
 * @param stmt The raw statement to classify.
 * @returns `true` if the wrapped node is a query, otherwise `false` -
 * including when the `RawStmt` carries no `stmt` at all.
 */
export function isQueryStatement(stmt: RawStmt): boolean {
  // Branch explicitly so a missing `stmt` yields `false` instead of
  // leaking the falsy `stmt.stmt` value (e.g. `undefined`) to the caller
  return stmt.stmt ? unwrapQueryStatement(stmt.stmt) !== undefined : false
}

/**
 * Checks whether a raw statement is a seed, i.e. a node that
 * `unwrapSeedStatement` accepts.
 *
 * @param stmt The raw statement to classify.
 * @returns `true` if the wrapped node is a seed, otherwise `false` -
 * including when the `RawStmt` carries no `stmt` at all.
 */
export function isSeedStatement(stmt: RawStmt): boolean {
  // Branch explicitly so a missing `stmt` yields `false` instead of
  // leaking the falsy `stmt.stmt` value (e.g. `undefined`) to the caller
  return stmt.stmt ? unwrapSeedStatement(stmt.stmt) !== undefined : false
}

/**
 * Checks whether a raw statement is a migration, i.e. a node that
 * `unwrapMigrationStatement` accepts.
 *
 * @param stmt The raw statement to classify.
 * @returns `true` if the wrapped node is a migration, otherwise `false` -
 * including when the `RawStmt` carries no `stmt` at all.
 */
export function isMigrationStatement(stmt: RawStmt): boolean {
  // Branch explicitly so a missing `stmt` yields `false` instead of
  // leaking the falsy `stmt.stmt` value (e.g. `undefined`) to the caller
  return stmt.stmt ? unwrapMigrationStatement(stmt.stmt) !== undefined : false
}

/**
 * Unwraps a node as a query statement.
 *
 * @returns The underlying `SelectStmt` when the node is a select that has
 * no `intoClause`, otherwise `undefined`.
 */
export function unwrapQueryStatement(node: Node) {
  const select = unwrapNode(node, 'SelectStmt')

  // A select that inserts into another table is not treated as a query
  if (!select || select.intoClause) {
    return undefined
  }

  return select
}

/**
 * Unwraps a node as a seed statement.
 *
 * @returns The underlying insert, update, delete, or merge statement, or a
 * `SelectStmt` that has an `intoClause`; otherwise `undefined`.
 */
export function unwrapSeedStatement(node: Node) {
  // Insert, update, delete, and merge statements always count as seeds
  const dataStatement = unwrapOneOfNode(node, [
    'InsertStmt',
    'UpdateStmt',
    'DeleteStmt',
    'MergeStmt',
  ])

  if (dataStatement) {
    return dataStatement
  }

  // A select only counts as a seed when it inserts into another table
  const select = unwrapNode(node, 'SelectStmt')

  return select?.intoClause ? select : undefined
}

/**
 * Unwraps a node as a migration statement.
 *
 * @returns The underlying statement when the node is one of the statement
 * types listed below, otherwise `undefined`.
 */
export function unwrapMigrationStatement(node: Node) {
  // `unwrapOneOfNode` already yields `undefined` when nothing matches,
  // so its result can be handed back directly
  return unwrapOneOfNode(node, [
    'ReturnStmt',
    'PLAssignStmt',
    'AlterTableStmt',
    'AlterDomainStmt',
    'SetOperationStmt',
    'GrantStmt',
    'GrantRoleStmt',
    'AlterDefaultPrivilegesStmt',
    'ClosePortalStmt',
    'ClusterStmt',
    'CopyStmt',
    'CreateStmt',
    'DefineStmt',
    'DropStmt',
    'TruncateStmt',
    'CommentStmt',
    'FetchStmt',
    'IndexStmt',
    'CreateFunctionStmt',
    'AlterFunctionStmt',
    'DoStmt',
    'RenameStmt',
    'RuleStmt',
    'NotifyStmt',
    'ListenStmt',
    'UnlistenStmt',
    'TransactionStmt',
    'ViewStmt',
    'LoadStmt',
    'CreateDomainStmt',
    'CreatedbStmt',
    'DropdbStmt',
    'VacuumStmt',
    'ExplainStmt',
    'CreateTableAsStmt',
    'CreateSeqStmt',
    'AlterSeqStmt',
    'VariableSetStmt',
    'VariableShowStmt',
    'DiscardStmt',
    'CreateTrigStmt',
    'CreatePLangStmt',
    'CreateRoleStmt',
    'AlterRoleStmt',
    'DropRoleStmt',
    'LockStmt',
    'ConstraintsSetStmt',
    'ReindexStmt',
    'CheckPointStmt',
    'CreateSchemaStmt',
    'AlterDatabaseStmt',
    'AlterDatabaseRefreshCollStmt',
    'AlterDatabaseSetStmt',
    'AlterRoleSetStmt',
    'CreateConversionStmt',
    'CreateCastStmt',
    'CreateOpClassStmt',
    'CreateOpFamilyStmt',
    'AlterOpFamilyStmt',
    'PrepareStmt',
    'ExecuteStmt',
    'DeallocateStmt',
    'DeclareCursorStmt',
    'CreateTableSpaceStmt',
    'DropTableSpaceStmt',
    'AlterObjectDependsStmt',
    'AlterObjectSchemaStmt',
    'AlterOwnerStmt',
    'AlterOperatorStmt',
    'AlterTypeStmt',
    'DropOwnedStmt',
    'ReassignOwnedStmt',
    'CompositeTypeStmt',
    'CreateEnumStmt',
    'CreateRangeStmt',
    'AlterEnumStmt',
    'AlterTSDictionaryStmt',
    'AlterTSConfigurationStmt',
    'CreateFdwStmt',
    'AlterFdwStmt',
    'CreateForeignServerStmt',
    'AlterForeignServerStmt',
    'CreateUserMappingStmt',
    'AlterUserMappingStmt',
    'DropUserMappingStmt',
    'AlterTableSpaceOptionsStmt',
    'AlterTableMoveAllStmt',
    'SecLabelStmt',
    'CreateForeignTableStmt',
    'ImportForeignSchemaStmt',
    'CreateExtensionStmt',
    'AlterExtensionStmt',
    'AlterExtensionContentsStmt',
    'CreateEventTrigStmt',
    'AlterEventTrigStmt',
    'RefreshMatViewStmt',
    'ReplicaIdentityStmt',
    'AlterSystemStmt',
    'CreatePolicyStmt',
    'AlterPolicyStmt',
    'CreateTransformStmt',
    'CreateAmStmt',
    'CreatePublicationStmt',
    'AlterPublicationStmt',
    'CreateSubscriptionStmt',
    'AlterSubscriptionStmt',
    'DropSubscriptionStmt',
    'CreateStatsStmt',
    'AlterCollationStmt',
    'CallStmt',
    'AlterStatsStmt',
  ])
}

/**
 * Describes a policy: its name, the relation it belongs to, the roles it
 * lists, and optional command / expression nodes.
 *
 * NOTE(review): nothing in this section constructs or reads this type, so
 * the field meanings are inferred from their names (they resemble the parts
 * of a Postgres `CREATE POLICY` statement) - confirm against the call sites.
 */
export type PolicyInfo = {
  name: string
  relation: string
  // Optional - presumably the command the policy applies to; TODO confirm
  command?: string
  roles: string[]
  // Optional parsed expressions - presumably the USING and WITH CHECK
  // clauses respectively; TODO confirm
  usingNode?: Node
  withCheckNode?: Node
}

/**
 * Extracts keys from a union type.
 *
 * `T extends T` makes the conditional distribute over every member of `T`,
 * so the result is the union of each member's keys. A plain `keyof T` on a
 * union would only yield the keys that all members have in common.
 */
export type ExtractKeys<T> = T extends T ? keyof T : never

/**
 * Filters strings ending in the specified suffix.
 *
 * When `T` is a union of string literals, each member is kept only if it
 * matches the template `${prefix}${S}`; members that don't match resolve
 * to `never` and drop out of the union.
 */
export type StringWithSuffix<T extends string, S extends string> = T extends `${infer _Prefix}${S}`
  ? T
  : never

/**
 * Extracts all nodes whose keys end in the specified suffix.
 *
 * For each member `U` of the `Node` union, the mapped type resolves to `U`
 * when one of its keys ends in `S` (and to `never` otherwise). Matching
 * members are then unwrapped via `U[keyof U]`, so the result is the union
 * of the underlying node values rather than the wrapper objects.
 *
 * NOTE(review): the `Node extends infer U` indirection appears to exist so
 * that the inner conditional is evaluated per union member - confirm before
 * simplifying.
 */
export type NodeWithSuffix<S extends string> = Node extends infer U
  ? U extends {
      [K in keyof U]: K extends `${infer _Prefix}${S}` ? U : never
    }[keyof U]
    ? U[keyof U]
    : never
  : never

/**
 * Every Node type.
 *
 * This is the union of the keys found on the members of `Node`
 * (for example `'SelectStmt'` or `'ColumnRef'`, both used in this module).
 */
export type NodeType = ExtractKeys<Node>

/**
 * Every statement Node type.
 *
 * This is the subset of `NodeType` whose names end in `'Stmt'`.
 */
export type StatementType = StringWithSuffix<NodeType, 'Stmt'>

/**
 * Every statement Node.
 *
 * This is the union of the unwrapped values of all `Node` members whose
 * key ends in `'Stmt'`.
 */
export type Statement = NodeWithSuffix<'Stmt'>

/**
 * Unwraps a Node to get its underlying value.
 *
 * Resolves to the value type stored under key `U` for the members of `T`
 * that have that key, and to `never` for members that don't.
 */
export type NodeValue<T extends Node, U extends NodeType> = T extends Record<U, infer V> ? V : never

/**
 * Error thrown by the assertion helpers in this module.
 *
 * The constructor's own frame is removed from the stack trace so that
 * debug tools point at the code that created the error rather than at
 * this constructor.
 */
export class AssertionError extends Error {
  /**
   * @param message Description of the failed assertion. May span
   * multiple lines.
   */
  constructor(message: string) {
    super(message)

    if (this.stack) {
      // Capture the current stack trace and split it into lines
      const stackLines = this.stack.split('\n')

      // V8 starts the stack with the stringified error ("Error: <message>"),
      // which occupies several lines when the message contains newlines.
      // Count those lines so that a stack frame is removed and never a line
      // of the message. If the stack does not start with that header, fall
      // back to assuming a single header line.
      const header = String(this)
      const headerLineCount = this.stack.startsWith(header) ? header.split('\n').length : 1

      // Remove the first line after the header, which is the current constructor
      stackLines.splice(headerLineCount, 1)

      // Reassign the modified stack trace back to the error object
      this.stack = stackLines.join('\n')
    }
  }
}

/**
 * Asserts that a value is not `undefined`, narrowing its type to `T`.
 *
 * Only `undefined` fails the assertion; every other value passes.
 *
 * @param value The value to check.
 * @param errorMessage Message for the error thrown on failure.
 * @throws {AssertionError} If `value` is `undefined`.
 */
export function assertDefined<T>(value: T | undefined, errorMessage: string): asserts value is T {
  if (value !== undefined) {
    return
  }

  throw new AssertionError(errorMessage)
}

/**
 * Asserts that a `Node` has the key `type`, narrowing it to `T`.
 *
 * @param node The node to check.
 * @param type The key that identifies the expected node type.
 * @param errorMessage Message for the error thrown on failure.
 * @throws {AssertionError} If `type` is not a key of `node`.
 */
export function assertNodeType<T extends Node>(
  node: Node,
  type: ExtractKeys<T>,
  errorMessage: string
): asserts node is T {
  const hasExpectedType = type in node

  if (hasExpectedType) {
    return
  }

  throw new AssertionError(errorMessage)
}

/**
 * Checks that a `Node` has the key `type` and hands back
 * the value stored under that key.
 *
 * @param node The node to unwrap.
 * @param type The key that identifies the expected node type.
 * @param errorMessage Message for the error thrown on failure.
 * @returns The unwrapped `Node` value.
 * @throws {AssertionError} If `type` is not a key of `node`.
 */
export function assertAndUnwrapNode<T extends Node, U extends NodeType>(
  node: Node,
  type: U,
  errorMessage: string
): NodeValue<T, U> {
  if (type in node) {
    return (node as any)[type]
  }

  throw new AssertionError(errorMessage)
}

/**
 * Reads the value a `Node` stores under the key `type`.
 *
 * @param node The node to unwrap.
 * @param type The key that identifies the expected node type.
 * @returns The value stored under `type`, or `undefined` when
 * `type` is not a key of `node`.
 */
export function unwrapNode<T extends Node, U extends NodeType>(
  node: Node,
  type: U
): NodeValue<T, U> | undefined {
  return type in node ? (node as any)[type] : undefined
}

/**
 * Unwraps a `Node` against a list of candidate node types.
 *
 * The candidates are tried in order and the first one whose unwrapped
 * value is truthy wins.
 *
 * @param node The node to unwrap.
 * @param types Non-empty list of candidate node types.
 * @returns The unwrapped `Node` value or `undefined` if
 * the node is not one of the node types in `types`.
 */
export function unwrapOneOfNode<T extends Node, U extends readonly [NodeType, ...NodeType[]]>(
  node: Node,
  types: U
): { [K in keyof U]: NodeValue<T, U[K]> }[number] | undefined {
  const matchingType = types.find((candidate) => Boolean(unwrapNode(node, candidate)))

  return matchingType ? unwrapNode(node, matchingType) : undefined
}

/**
 * Runs `fn` against the left side of the expression and, if that
 * throws, against the right side.
 *
 * @param expression Expression whose `lexpr` and `rexpr` must both exist.
 * @param fn Callback invoked with a side's node and which side it is.
 * @returns The result of the first call to `fn` that does not throw.
 * @throws {AssertionError} If either side of the expression is missing.
 * @throws The error from the left side if `fn` throws on both sides.
 */
export function assertEitherSideOfExpression<U>(
  expression: A_Expr,
  fn: (node: Node, side: 'left' | 'right') => U
): U {
  assertDefined(expression.lexpr, 'Expected left side of expression to exist')
  assertDefined(expression.rexpr, 'Expected right side of expression to exist')

  let leftError: unknown

  try {
    return fn(expression.lexpr, 'left')
  } catch (error) {
    // Remember the left-side failure in case the right side fails too
    leftError = error
  }

  try {
    return fn(expression.rexpr, 'right')
  } catch {
    // Both sides failed, so report the left-side error
    throw leftError
  }
}
/**
 * Asserts that `firstFn` accepts one side of the expression and
 * `secondFn` accepts the opposite side.
 *
 * `firstFn` is tried on the left side and then on the right side.
 * Whichever side it accepts, `secondFn` must then accept the other one.
 *
 * If `firstFn` rejects both sides, `secondFn` is run once against the
 * left side purely to choose which `firstFn` error to report: the
 * right-side error when `secondFn` accepts the left side, otherwise
 * the left-side error.
 *
 * The type parameter `U` is unused by the body.
 *
 * @throws {AssertionError} If either side of the expression is missing.
 * @throws Whatever `firstFn` or `secondFn` throws when no valid
 * assignment of sides exists.
 */
export function assertEachSideOfExpression<U>(
  expression: A_Expr,
  firstFn: (node: Node) => void,
  secondFn: (node: Node) => void
): void {
  assertDefined(expression.lexpr, 'Expected left side of expression to exist')
  assertDefined(expression.rexpr, 'Expected right side of expression to exist')

  // Start by assuming `firstFn` owns the left side,
  // which leaves the right side for `secondFn`
  const leftCandidate = expression.lexpr
  let remainingSide: Node = expression.rexpr

  try {
    firstFn(leftCandidate)
  } catch (leftError) {
    // Swap the sides and give `firstFn` another chance on the right
    const rightCandidate = expression.rexpr
    remainingSide = expression.lexpr

    try {
      firstFn(rightCandidate)
    } catch (rightError) {
      // `firstFn` rejected both sides, so an error is thrown either way.
      // Probe the left side with `secondFn` to decide which one to show.
      try {
        secondFn(remainingSide)
      } catch {
        throw leftError
      }
      throw rightError
    }
  }

  // `firstFn` accepted one side, so the other side
  // has to satisfy `secondFn`; any error propagates as-is
  secondFn(remainingSide)
}

/**
 * Renders every entry of a target list.
 *
 * Each node must be a `ResTarget` with a `val`; that `val` is passed
 * to `renderTarget`.
 *
 * @param targets The target list nodes.
 * @param renderTarget Callback that renders a single target value.
 * @returns The rendered targets, in the same order as `targets`.
 * @throws {AssertionError} If a node is not a `ResTarget` or has no `val`.
 */
export function renderTargets<T>(targets: Node[], renderTarget: (node: Node) => T): T[] {
  const renderOne = (node: Node): T => {
    const resTarget = assertAndUnwrapNode(
      node,
      'ResTarget',
      'Expected target list to contain ResTargets'
    )
    const value = resTarget.val

    assertDefined(value, 'Expected ResTarget to have a val')
    return renderTarget(value)
  }

  return targets.map((node) => renderOne(node))
}

/**
 * Renders a column reference by rendering its `fields`.
 *
 * @throws {AssertionError} If the column has no `fields`.
 */
export function renderColumn(column: ColumnRef) {
  const columnFields = column.fields

  assertDefined(columnFields, 'Expected column to have fields')
  return renderFields(columnFields)
}

/**
 * Requires `node` to be a `ColumnRef` and renders it.
 *
 * @param node The node expected to be a column reference.
 * @param errorMessage Message for the error thrown when it is not.
 * @throws {AssertionError} If `node` is not a `ColumnRef`.
 */
export function assertAndRenderColumn(node: Node, errorMessage: string) {
  return renderColumn(assertAndUnwrapNode(node, 'ColumnRef', errorMessage))
}

/**
 * Renders a JSON path expression (`->` / `->>`) as a string.
 *
 * The left side may be a constant, a column, a function call, or a
 * nested JSON expression. The right side must be a constant. Constants
 * are wrapped in single quotes and must not be floats.
 *
 * @returns The left side, the operator, and the right side concatenated.
 * @throws {AssertionError} If the expression has no operator, more than
 * one operator, an operator other than `->` / `->>`, a missing side, or a
 * side that is not a supported node.
 */
export function renderJsonExpression(expression: A_Expr): string {
  assertDefined(expression.name, 'Expected expression to have an operator')

  if (expression.name.length > 1) {
    throw new AssertionError('Only one JSON operator supported per expression')
  }

  const operatorNode = expression.name[0]
  const operatorString = assertAndUnwrapNode(
    operatorNode,
    'String',
    'Expected JSON operator to be a string'
  )
  assertDefined(operatorString.sval, 'PG string expected to have an sval')
  const operator = operatorString.sval

  if (operator !== '->' && operator !== '->>') {
    throw new AssertionError(`Invalid JSON operator ${operator}`)
  }

  assertDefined(expression.lexpr, 'Expected JSON expression to have a left-side component')
  assertDefined(expression.rexpr, 'Expected JSON expression to have a right-side component')

  // Renders a constant path segment wrapped in single quotes
  const renderPathConstant = (constant: A_Const): string => {
    // JSON path cannot contain a float
    if ('fval' in constant) {
      throw new AssertionError('Invalid JSON path: Expression cannot contain a float')
    }
    return `'${parseConstant(constant)}'`
  }

  // Renders the left side, trying each supported node type in turn
  const renderLeftSide = (node: Node): string => {
    const constant = unwrapNode(node, 'A_Const')
    if (constant) {
      return renderPathConstant(constant)
    }

    const column = unwrapNode(node, 'ColumnRef')
    if (column) {
      return renderColumn(column)
    }

    const funcCall = unwrapNode(node, 'FuncCall')
    if (funcCall) {
      assertDefined(funcCall.funcname, 'Expected function call to have a name')
      return `${renderFields(funcCall.funcname)}()`
    }

    // A nested expression is rendered recursively
    const nestedExpression = unwrapNode(node, 'A_Expr')
    if (nestedExpression) {
      return renderJsonExpression(nestedExpression)
    }

    throw new AssertionError('Invalid JSON path')
  }

  const left = renderLeftSide(expression.lexpr)

  // Only constants are accepted on the right side
  const rightConstant = unwrapNode(expression.rexpr, 'A_Const')

  if (!rightConstant) {
    throw new AssertionError('Invalid JSON path')
  }

  const right = renderPathConstant(rightConstant)

  return `${left}${operator}${right}`
}

/**
 * Renders a list of field nodes as a dot-separated name.
 *
 * `String` fields contribute their `sval` (fields whose `sval` is
 * `undefined` are skipped) and `A_Star` fields contribute `*`.
 *
 * @throws {Error} If a field is neither a `String` nor an `A_Star`.
 */
export function renderFields(fields: Node[]) {
  const nameSegments = fields.reduce<string[]>((segments, field) => {
    const stringField = unwrapNode(field, 'String')
    const starField = unwrapNode(field, 'A_Star')

    if (stringField !== undefined) {
      const segment = stringField.sval

      // Leave out strings that carry no value
      if (segment !== undefined) {
        segments.push(segment)
      }
      return segments
    }

    if (starField !== undefined) {
      segments.push('*')
      return segments
    }

    const [internalType] = Object.keys(field)
    throw new Error(`Unsupported internal type '${internalType}' for fields`)
  }, [])

  return nameSegments.join('.')
}

/**
 * Converts a constant node into a plain string or number.
 *
 * The `sval`, `ival`, and `fval` keys are checked in that order. A key
 * that is present without an inner value yields `''` for strings and
 * `0` for integers and floats.
 *
 * @throws {AssertionError} If the constant has none of those keys.
 */
export function parseConstant(constant: A_Const) {
  if ('sval' in constant) {
    return constant.sval?.sval ?? ''
  }

  if ('ival' in constant) {
    // The PG parser turns 0 into undefined, so restore it here
    return constant.ival?.ival ?? 0
  }

  if ('fval' in constant) {
    const rawFloat = constant.fval?.fval
    return rawFloat ? parseFloat(rawFloat) : 0
  }

  throw new AssertionError(`Constant values must be a string, integer, or float`)
}
